{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xhhr4iSQuQq_"
   },
   "source": [
    "# Efficient SAM Segment Everything Example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AIrAUKnLClPD"
   },
   "source": [
    "This script provides example for how to get segment everything visualization result from EfficientSAM using weight file.\n",
    "\n",
    "The basic method is same as SAM, we generate a grid of point prompts over the image and get the masks. Currently we directly compute all the masks in one time so it requires a large memory. If you face OOM Issue, you can consider reduce the GRID_SIZE. We will update the efficient version by calculating the mask in local crops in the future.\n",
    "\n",
    "the post processing part is from original SAM project to get a better visualization result, part of the visualization code are borrow from MobileSAM project, many thanks!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zylNfpYIuXeR"
   },
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "I64YhiKsS2KU"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from torchvision.transforms import ToTensor\n",
    "from PIL import Image\n",
    "import io\n",
    "import cv2\n",
    "GRID_SIZE = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "HvPqfYv5mQdG",
    "outputId": "46844f18-2c32-4060-e9d2-080eb4ff0ce1"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting git+https://github.com/facebookresearch/segment-anything.git\n",
      "  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-i4go_lep\n",
      "  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-i4go_lep\n",
      "  Resolved https://github.com/facebookresearch/segment-anything.git to commit dca509fe793f601edb92606367a655c15ac00fdf\n",
      "  Preparing metadata (setup.py) ... \u001b[?25ldone\n",
      "\u001b[?25h"
     ]
    }
   ],
   "source": [
    "!pip install git+https://github.com/facebookresearch/segment-anything.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pw_4lyT8uMvy"
   },
   "source": [
    "# Segment Everything Function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3178gjhfw-an"
   },
   "source": [
    "post processing script (borrow from SAM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "PxBjHaLNkkTQ"
   },
   "outputs": [],
   "source": [
    "from segment_anything.utils.amg import (\n",
    "    batched_mask_to_box,\n",
    "    calculate_stability_score,\n",
    "    mask_to_rle_pytorch,\n",
    "    remove_small_regions,\n",
    "    rle_to_mask,\n",
    ")\n",
    "from torchvision.ops.boxes import batched_nms, box_area\n",
    "def process_small_region(rles):\n",
    "        new_masks = []\n",
    "        scores = []\n",
    "        min_area = 100\n",
    "        nms_thresh = 0.7\n",
    "        for rle in rles:\n",
    "            mask = rle_to_mask(rle[0])\n",
    "\n",
    "            mask, changed = remove_small_regions(mask, min_area, mode=\"holes\")\n",
    "            unchanged = not changed\n",
    "            mask, changed = remove_small_regions(mask, min_area, mode=\"islands\")\n",
    "            unchanged = unchanged and not changed\n",
    "\n",
    "            new_masks.append(torch.as_tensor(mask).unsqueeze(0))\n",
    "            # Give score=0 to changed masks and score=1 to unchanged masks\n",
    "            # so NMS will prefer ones that didn't need postprocessing\n",
    "            scores.append(float(unchanged))\n",
    "\n",
    "        # Recalculate boxes and remove any new duplicates\n",
    "        masks = torch.cat(new_masks, dim=0)\n",
    "        boxes = batched_mask_to_box(masks)\n",
    "        keep_by_nms = batched_nms(\n",
    "            boxes.float(),\n",
    "            torch.as_tensor(scores),\n",
    "            torch.zeros_like(boxes[:, 0]),  # categories\n",
    "            iou_threshold=nms_thresh,\n",
    "        )\n",
    "\n",
    "        # Only recalculate RLEs for masks that have changed\n",
    "        for i_mask in keep_by_nms:\n",
    "            if scores[i_mask] == 0.0:\n",
    "                mask_torch = masks[i_mask].unsqueeze(0)\n",
    "                rles[i_mask] = mask_to_rle_pytorch(mask_torch)\n",
    "        masks = [rle_to_mask(rles[i][0]) for i in keep_by_nms]\n",
    "        return masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "0GFlJhrujr6v"
   },
   "outputs": [],
   "source": [
    "def get_predictions_given_embeddings_and_queries(img, points, point_labels, model):\n",
    "    predicted_masks, predicted_iou = model(\n",
    "        img[None, ...], points, point_labels\n",
    "    )\n",
    "    sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)\n",
    "    predicted_iou_scores = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)\n",
    "    predicted_masks = torch.take_along_dim(\n",
    "        predicted_masks, sorted_ids[..., None, None], dim=2\n",
    "    )\n",
    "    predicted_masks = predicted_masks[0]\n",
    "    iou = predicted_iou_scores[0, :, 0]\n",
    "    index_iou = iou > 0.7\n",
    "    iou_ = iou[index_iou]\n",
    "    masks = predicted_masks[index_iou]\n",
    "    score = calculate_stability_score(masks, 0.0, 1.0)\n",
    "    score = score[:, 0]\n",
    "    index = score > 0.9\n",
    "    score_ = score[index]\n",
    "    masks = masks[index]\n",
    "    iou_ = iou_[index]\n",
    "    masks = torch.ge(masks, 0.0)\n",
    "    return masks, iou_\n",
    "\n",
    "def run_everything_ours(img_path, model):\n",
    "    model = model.cpu()\n",
    "    image = cv2.imread(image_path)\n",
    "    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
    "    img_tensor = ToTensor()(image)\n",
    "    _, original_image_h, original_image_w = img_tensor.shape\n",
    "    xy = []\n",
    "    for i in range(GRID_SIZE):\n",
    "        curr_x = 0.5 + i / GRID_SIZE * original_image_w\n",
    "        for j in range(GRID_SIZE):\n",
    "            curr_y = 0.5 + j / GRID_SIZE * original_image_h\n",
    "            xy.append([curr_x, curr_y])\n",
    "    xy = torch.from_numpy(np.array(xy))\n",
    "    points = xy\n",
    "    num_pts = xy.shape[0]\n",
    "    point_labels = torch.ones(num_pts, 1)\n",
    "    with torch.no_grad():\n",
    "      predicted_masks, predicted_iou = get_predictions_given_embeddings_and_queries(\n",
    "              img_tensor.cpu(),\n",
    "              points.reshape(1, num_pts, 1, 2).cpu(),\n",
    "              point_labels.reshape(1, num_pts, 1).cpu(),\n",
    "              model.cpu(),\n",
    "          )\n",
    "    rle = [mask_to_rle_pytorch(m[0:1]) for m in predicted_masks]\n",
    "    predicted_masks = process_small_region(rle)\n",
    "    return predicted_masks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-83WUeNPuJnT"
   },
   "source": [
    "# Visualization Related"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "QKWt76-AG31h"
   },
   "outputs": [],
   "source": [
    "def show_anns_ours(mask, ax):\n",
    "    ax.set_autoscale_on(False)\n",
    "    img = np.ones((mask[0].shape[0], mask[0].shape[1], 4))\n",
    "    img[:,:,3] = 0\n",
    "    for ann in mask:\n",
    "        m = ann\n",
    "        color_mask = np.concatenate([np.random.random(3), [0.5]])\n",
    "        img[m] = color_mask\n",
    "    ax.imshow(img)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GHj10cGetlGN"
   },
   "source": [
    "# Create the model and load the weights from the checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "vnzGws0DBJsj",
    "outputId": "cbe48e56-d961-4c19-e744-c257158a900c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fatal: destination path 'EfficientSAM' already exists and is not an empty directory.\n"
     ]
    }
   ],
   "source": [
    "!git clone https://github.com/yformer/EfficientSAM.git\n",
    "import os\n",
    "os.chdir(\"EfficientSAM\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "3HR4CCpYUpAI",
    "outputId": "25a3829b-7073-4bbd-8e01-a92c699deadd"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/crisi/Pattern_recognition/EfficientSAM/EfficientSAM/efficient_sam/efficient_sam.py:303: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  state_dict = torch.load(f, map_location=\"cpu\")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "EfficientSam(\n",
       "  (image_encoder): ImageEncoderViT(\n",
       "    (patch_embed): PatchEmbed(\n",
       "      (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))\n",
       "    )\n",
       "    (blocks): ModuleList(\n",
       "      (0-11): 12 x Block(\n",
       "        (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): Attention(\n",
       "          (qkv): Linear(in_features=384, out_features=1152, bias=True)\n",
       "          (proj): Linear(in_features=384, out_features=384, bias=True)\n",
       "        )\n",
       "        (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): Mlp(\n",
       "          (fc1): Linear(in_features=384, out_features=1536, bias=True)\n",
       "          (act): GELU(approximate='none')\n",
       "          (fc2): Linear(in_features=1536, out_features=384, bias=True)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (neck): Sequential(\n",
       "      (0): Conv2d(384, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (1): LayerNorm2d()\n",
       "      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (3): LayerNorm2d()\n",
       "    )\n",
       "  )\n",
       "  (prompt_encoder): PromptEncoder(\n",
       "    (pe_layer): PositionEmbeddingRandom()\n",
       "    (invalid_points): Embedding(1, 256)\n",
       "    (point_embeddings): Embedding(1, 256)\n",
       "    (bbox_top_left_embeddings): Embedding(1, 256)\n",
       "    (bbox_bottom_right_embeddings): Embedding(1, 256)\n",
       "  )\n",
       "  (mask_decoder): MaskDecoder(\n",
       "    (transformer): TwoWayTransformer(\n",
       "      (layers): ModuleList(\n",
       "        (0-1): 2 x TwoWayAttentionBlock(\n",
       "          (self_attn): AttentionForTwoWayAttentionBlock(\n",
       "            (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
       "            (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
       "            (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
       "            (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
       "          )\n",
       "          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "          (cross_attn_token_to_image): AttentionForTwoWayAttentionBlock(\n",
       "            (q_proj): Linear(in_features=256, out_features=128, bias=True)\n",
       "            (k_proj): Linear(in_features=256, out_features=128, bias=True)\n",
       "            (v_proj): Linear(in_features=256, out_features=128, bias=True)\n",
       "            (out_proj): Linear(in_features=128, out_features=256, bias=True)\n",
       "          )\n",
       "          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): MLPBlock(\n",
       "            (layers): ModuleList(\n",
       "              (0): Sequential(\n",
       "                (0): Linear(in_features=256, out_features=2048, bias=True)\n",
       "                (1): GELU(approximate='none')\n",
       "              )\n",
       "            )\n",
       "            (fc): Linear(in_features=2048, out_features=256, bias=True)\n",
       "          )\n",
       "          (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "          (norm4): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "          (cross_attn_image_to_token): AttentionForTwoWayAttentionBlock(\n",
       "            (q_proj): Linear(in_features=256, out_features=128, bias=True)\n",
       "            (k_proj): Linear(in_features=256, out_features=128, bias=True)\n",
       "            (v_proj): Linear(in_features=256, out_features=128, bias=True)\n",
       "            (out_proj): Linear(in_features=128, out_features=256, bias=True)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (final_attn_token_to_image): AttentionForTwoWayAttentionBlock(\n",
       "        (q_proj): Linear(in_features=256, out_features=128, bias=True)\n",
       "        (k_proj): Linear(in_features=256, out_features=128, bias=True)\n",
       "        (v_proj): Linear(in_features=256, out_features=128, bias=True)\n",
       "        (out_proj): Linear(in_features=128, out_features=256, bias=True)\n",
       "      )\n",
       "      (norm_final_attn): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "    (iou_token): Embedding(1, 256)\n",
       "    (mask_tokens): Embedding(4, 256)\n",
       "    (final_output_upscaling_layers): ModuleList(\n",
       "      (0): Sequential(\n",
       "        (0): ConvTranspose2d(256, 64, kernel_size=(2, 2), stride=(2, 2))\n",
       "        (1): GroupNorm(1, 64, eps=1e-05, affine=True)\n",
       "        (2): GELU(approximate='none')\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))\n",
       "        (1): Identity()\n",
       "        (2): GELU(approximate='none')\n",
       "      )\n",
       "    )\n",
       "    (output_hypernetworks_mlps): ModuleList(\n",
       "      (0-3): 4 x MLPBlock(\n",
       "        (layers): ModuleList(\n",
       "          (0-1): 2 x Sequential(\n",
       "            (0): Linear(in_features=256, out_features=256, bias=True)\n",
       "            (1): GELU(approximate='none')\n",
       "          )\n",
       "        )\n",
       "        (fc): Linear(in_features=256, out_features=32, bias=True)\n",
       "      )\n",
       "    )\n",
       "    (iou_prediction_head): MLPBlock(\n",
       "      (layers): ModuleList(\n",
       "        (0-1): 2 x Sequential(\n",
       "          (0): Linear(in_features=256, out_features=256, bias=True)\n",
       "          (1): GELU(approximate='none')\n",
       "        )\n",
       "      )\n",
       "      (fc): Linear(in_features=256, out_features=4, bias=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from efficient_sam.build_efficient_sam import build_efficient_sam_vits\n",
    "import zipfile\n",
    "\n",
    "with zipfile.ZipFile(\"weights/efficient_sam_vits.pt.zip\", 'r') as zip_ref:\n",
    "    zip_ref.extractall(\"weights\")\n",
    "efficient_sam_vits_model = build_efficient_sam_vits()\n",
    "efficient_sam_vits_model.eval()\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "b76-PTdKuidf"
   },
   "source": [
    "## Segment Everything Example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "o6ed2uLDCSDn"
   },
   "source": [
    "prepare your own image here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 248
    },
    "id": "Xq7KloryJrhf",
    "outputId": "8fc7e2c8-cd2c-4c48-db78-98ef9f567e75"
   },
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 2, figsize=(30, 30))\n",
    "image_path = \"../../input_data/0854.png\"\n",
    "image = np.array(Image.open(image_path))\n",
    "ax[0].imshow(image)\n",
    "ax[0].title.set_text(\"Original\")\n",
    "ax[0].axis('off')\n",
    "\n",
    "ax[1].imshow(image)\n",
    "mask_efficient_sam_vits = run_everything_ours(image_path, efficient_sam_vits_model)\n",
    "show_anns_ours(mask_efficient_sam_vits, ax[1])\n",
    "ax[1].title.set_text(\"EfficientSAM\")\n",
    "ax[1].axis('off')\n",
    "\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "machine_shape": "hm",
   "provenance": []
  },
  "custom": {
   "cells": [],
   "metadata": {
    "custom": {
     "cells": [],
     "metadata": {
      "accelerator": "GPU",
      "colab": {
       "machine_shape": "hm",
       "provenance": []
      },
      "fileHeader": "",
      "fileUid": "f337ddbb-4ec7-4bc4-8c8b-f31305249752",
      "isAdHoc": false,
      "kernelspec": {
       "display_name": "Python 3",
       "name": "python3"
      },
      "language_info": {
       "name": "python"
      }
     },
     "nbformat": 4,
     "nbformat_minor": 0
    },
    "fileHeader": "",
    "fileUid": "e9a56628-4146-43e5-85f7-3d332cf3b1a2",
    "indentAmount": 2,
    "isAdHoc": false,
    "language_info": {
     "name": "plaintext"
    }
   },
   "nbformat": 4,
   "nbformat_minor": 2
  },
  "indentAmount": 2,
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
